{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PaddleSlim Distillation知识蒸馏简介与实验\n",
    "\n",
    "一般情况下，模型参数量越多，结构越复杂，其性能越好，但参数也越冗余，运算量和资源消耗也越大。**知识蒸馏**就是一种将大模型学习到的有用信息（Dark Knowledge）压缩进更小更快的模型，而获得可以匹敌大模型结果的方法。\n",
    "\n",
    "在本文中性能强劲的大模型被称为teacher, 性能稍逊但体积较小的模型被称为student。示例包含以下步骤：\n",
    "\n",
    "1. 导入依赖\n",
    "2. 定义student_program和teacher_program\n",
    "3. 选择特征图\n",
    "4. 合并program (merge)并添加蒸馏loss\n",
    "5. 模型训练\n",
    "\n",
    "\n",
    "## 1. 导入依赖\n",
    "PaddleSlim依赖Paddle1.7版本，请确认已正确安装Paddle，然后按以下方式导入Paddle、PaddleSlim以及其他依赖:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "import paddleslim as slim\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 定义student_program和teacher_program\n",
    "\n",
    "本教程在MNIST数据集上进行知识蒸馏的训练和验证，输入图片尺寸为`[1, 28, 28]`，输出类别数为10。\n",
    "选择`ResNet50`作为teacher对`MobileNet`结构的student进行蒸馏训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = models.__dict__['MobileNet']()\n",
    "student_program = fluid.Program()\n",
    "student_startup = fluid.Program()\n",
    "with fluid.program_guard(student_program, student_startup):\n",
    "    image = fluid.data(\n",
    "        name='image', shape=[None] + [1, 28, 28], dtype='float32')\n",
    "    label = fluid.data(name='label', shape=[None, 1], dtype='int64')\n",
    "    out = model.net(input=image, class_dim=10)\n",
    "    cost = fluid.layers.cross_entropy(input=out, label=label)\n",
    "    avg_cost = fluid.layers.mean(x=cost)\n",
    "    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)\n",
    "    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teacher_model = models.__dict__['ResNet50']()\n",
    "teacher_program = fluid.Program()\n",
    "teacher_startup = fluid.Program()\n",
    "with fluid.program_guard(teacher_program, teacher_startup):\n",
    "    with fluid.unique_name.guard():\n",
    "        image = fluid.data(\n",
    "            name='image', shape=[None] + [1, 28, 28], dtype='float32')\n",
    "        predict = teacher_model.net(image, class_dim=10)\n",
    "exe = fluid.Executor(fluid.CPUPlace())\n",
    "exe.run(teacher_startup)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 选择特征图\n",
    "我们可以用student_的list_vars方法来观察其中全部的Variables，从中选出一个或多个变量（Variable）来拟合teacher相应的变量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get all student variables\n",
    "student_vars = []\n",
    "for v in student_program.list_vars():\n",
    "    student_vars.append((v.name, v.shape))\n",
    "#uncomment the following lines to observe student's variables for distillation\n",
    "#print(\"=\"*50+\"student_model_vars\"+\"=\"*50)\n",
    "#print(student_vars)\n",
    "\n",
    "# get all teacher variables\n",
    "teacher_vars = []\n",
    "for v in teacher_program.list_vars():\n",
    "    teacher_vars.append((v.name, v.shape))\n",
    "#uncomment the following lines to observe teacher's variables for distillation\n",
    "#print(\"=\"*50+\"teacher_model_vars\"+\"=\"*50)\n",
    "#print(teacher_vars)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "经过筛选我们可以看到，teacher_program中的'bn5c_branch2b.output.1.tmp_3'和student_program的'depthwise_conv2d_11.tmp_0'尺寸一致，可以组成蒸馏损失函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 合并program (merge)并添加蒸馏loss\n",
    "merge操作将student_program和teacher_program中的所有Variables和Op都将被添加到同一个Program中，同时为了避免两个program中有同名变量会引起命名冲突，merge也会为teacher_program中的Variables添加一个同一的命名前缀name_prefix，其默认值是'teacher_'\n",
    "\n",
    "为了确保teacher网络和student网络输入的数据是一样的，merge操作也会对两个program的输入数据层进行合并操作，所以需要指定一个数据层名称的映射关系data_name_map，key是teacher的输入数据名称，value是student的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_name_map = {'image': 'image'}\n",
    "main = slim.dist.merge(teacher_program, student_program, data_name_map, fluid.CPUPlace())\n",
    "with fluid.program_guard(student_program, student_startup):\n",
    "    l2_loss = slim.dist.l2_loss('teacher_bn5c_branch2b.output.1.tmp_3', 'depthwise_conv2d_11.tmp_0', student_program)\n",
    "    loss = l2_loss + avg_cost\n",
    "    opt = fluid.optimizer.Momentum(0.01, 0.9)\n",
    "    opt.minimize(loss)\n",
    "exe.run(student_startup)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. 模型训练\n",
    "\n",
    "为了快速执行该示例，我们选取简单的MNIST数据，Paddle框架的`paddle.dataset.mnist`包定义了MNIST数据的下载和读取。\n",
    "代码如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_reader = paddle.batch(\n",
    "    paddle.dataset.mnist.train(), batch_size=128, drop_last=True)\n",
    "train_feeder = fluid.DataFeeder(['image', 'label'], fluid.CPUPlace(), student_program)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for data in train_reader():\n",
    "    acc1, acc5, loss_np = exe.run(student_program, feed=train_feeder.feed(data), fetch_list=[acc_top1.name, acc_top5.name, loss.name])\n",
    "    print(\"Acc1: {:.6f}, Acc5: {:.6f}, Loss: {:.6f}\".format(acc1.mean(), acc5.mean(), loss_np.mean()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
